"""
获取LSTM轨迹数据集
"""
import os
import cv2
import math
import config
import tensorflow as tf
import math_utility as mu
import gesture_recognition_utility as gu
from CNM import point_matching


# ---- Detection parameters (all values come from the project-wide config module) ----
RESCALE_RATE = config.RESCALE_RATE      # frames are shrunk by this factor before heat-map detection
KERNEL_SIZE = config.KERNEL_SIZE
FAST_RADIUS = config.FAST_RADIUS
SET_MAX_LENGTH = config.SET_MAX_LENGTH
MIN_DISTANCE = config.MIN_DISTANCE      # matched fingertip pairs closer than this are ignored

# ---- Model-related parameters ----
REVISE_WIDTH = config.REVISE_WIDTH      # every frame is resized to REVISE_WIDTH x REVISE_HEIGHT
REVISE_HEIGHT = config.REVISE_HEIGHT
MODEL_WIDTH = config.MODEL_WIDTH
MODEL_HEIGHT = config.MODEL_HEIGHT
SLOT_NUMBERS = config.SLOT_NUMBERS      # number of per-frame fingertip slots (one angle per slot)

# ---- Filesystem paths ----
MODEL_DIR = './CNN/Model/model.meta'            # meta-graph file of the fingertip CNN
OUT_FILE = './LSTM/Data/data.txt'               # dataset file; one sample line is appended per video
VIDEO_FILE_DIR = './LSTM/Trainsets_LSTM/'       # root directory holding one sub-directory of videos per label


# 对轨迹编码1-12
# Encode a trajectory direction as one of the codes 1-12
def angle_to_number(angle):
    """Quantize a normalized trajectory angle into one of 12 direction codes.

    Args:
        angle: Direction expressed as a fraction of a full turn
            (``angle_radians / (2 * pi)``), or the sentinel -1 for an empty slot.

    Returns:
        The sentinel -1 unchanged for an empty slot. Otherwise an int in 1..12,
        where code k covers the sector [(k - 1) / 12, k / 12) of the circle.
        Inputs are wrapped modulo one full turn. A value of exactly 1.0 (which
        rounding to two decimals can produce) therefore maps to 1, the same as
        0.0, instead of overflowing to 13.
    """
    if angle == -1:
        return angle
    # floor + modulo wraps values outside [0, 1) back onto the circle.
    return math.floor(angle * 12) % 12 + 1


def _load_model(sess):
    """Restore the fingertip CNN into ``sess``.

    Returns:
        (x, logits): the ``input/x:0`` input tensor and the ``logits_eval:0``
        output tensor of the restored graph.
    """
    saver = tf.train.import_meta_graph(MODEL_DIR)
    # MODEL_DIR is the .meta file; its directory also holds the checkpoint files.
    saver.restore(sess, tf.train.latest_checkpoint(os.path.dirname(MODEL_DIR)))
    graph = tf.get_default_graph()
    return graph.get_tensor_by_name('input/x:0'), graph.get_tensor_by_name('logits_eval:0')


def _preprocess(frame):
    """Force portrait orientation and resize the frame to (REVISE_WIDTH, REVISE_HEIGHT)."""
    if frame.shape[1] > frame.shape[0]:
        frame = cv2.flip(cv2.transpose(frame), 1)
    return cv2.resize(frame, (REVISE_WIDTH, REVISE_HEIGHT))


def _detect_fingertips(sess, x, logits, frame):
    """Return the candidate points of ``frame`` that the CNN classifies as fingertips."""
    full = _preprocess(frame)
    # Detect candidates on a copy shrunk by RESCALE_RATE, then cut patches from the full-size frame.
    small = cv2.resize(full, (int(REVISE_WIDTH / RESCALE_RATE), int(REVISE_HEIGHT / RESCALE_RATE)))
    _, point_set = gu.get_heatmap(small)
    input_list, out_points = gu.cut_image(full, point_set)
    if len(out_points) == 0:
        # Nothing to classify; feeding an empty batch to the network would fail.
        return []
    scores = sess.run(logits, {x: input_list})
    # Use numpy's argmax on the fetched array: building tf.argmax here would add a
    # new op to the graph on every frame and grow memory without bound.
    labels = scores.argmax(axis=1)
    # Class 1 = fingertip (positive), class 0 = negative.
    return [point for point, label in zip(out_points, labels) if int(label) == 1]


def _frame_angles(fingertips, last_points):
    """Encode the fingertip motion between two consecutive frames.

    Returns:
        A list of SLOT_NUMBERS values. Each moved slot holds the direction as a
        fraction of a full turn, rounded to two decimals; unmoved slots hold -1.
        Returns None when fewer than two slots moved.
    """
    match = point_matching(fingertips, last_points, slot=SLOT_NUMBERS)
    ang_set = [-1] * SLOT_NUMBERS
    for pair in match:
        point_sta = fingertips[pair[0]]
        point_end = last_points[pair[1]]
        distance = mu.get_distance(point_sta, point_end)
        angle = mu.get_angle(point_sta, point_end)
        if distance > MIN_DISTANCE and angle != -1:
            ang_set[pair[0]] = round(angle / (2 * math.pi), 2)
    # Decided once per frame, so each frame contributes at most one row.
    if ang_set.count(-1) < SLOT_NUMBERS - 1:
        return ang_set
    return None


def _extract_angle_sequence(sess, x, logits, video_path):
    """Run fingertip tracking over every frame of a video; return the per-frame slot angles."""
    ang_list = []
    last_points = []
    capture = cv2.VideoCapture(video_path)
    try:
        frame_count = 0
        while True:
            _, frame = capture.read()
            if frame is None:   # end of stream (or the video could not be opened)
                break
            fingertips = _detect_fingertips(sess, x, logits, frame)
            if frame_count > 0:
                ang_set = _frame_angles(fingertips, last_points)
                if ang_set is not None:
                    ang_list.append(ang_set)
                    print(ang_set)
            last_points = fingertips
            frame_count += 1
    finally:
        capture.release()
    return ang_list


def _write_sample(ang_list, label):
    """Append one sample line to OUT_FILE.

    Layout: for each slot, the space-prefixed direction code of every recorded
    frame followed by ';'. The line ends with ``label`` and a newline.
    """
    parts = []
    for slot in range(SLOT_NUMBERS):
        parts.extend(' ' + str(angle_to_number(ang_set[slot])) for ang_set in ang_list)
        parts.append(';')
    parts.append(str(label) + '\n')
    with open(OUT_FILE, 'a+') as out_file:
        out_file.write(''.join(parts))


if __name__ == '__main__':
    # The integer written at the end of each sample line is the index of the
    # video's sub-directory. This is presumably the gesture class label; sorting
    # makes it deterministic, whereas os.listdir order is arbitrary.
    class_dirs = sorted(
        name for name in os.listdir(VIDEO_FILE_DIR)
        if os.path.isdir(os.path.join(VIDEO_FILE_DIR, name))
    )
    tf.reset_default_graph()
    with tf.Session() as sess:
        # The graph no longer grows per frame, so the model is restored only once.
        x, logits = _load_model(sess)
        for label, class_dir in enumerate(class_dirs):
            class_path = os.path.join(VIDEO_FILE_DIR, class_dir)
            for video_name in sorted(os.listdir(class_path)):
                video_path = os.path.join(class_path, video_name)
                print(video_path)
                ang_list = _extract_angle_sequence(sess, x, logits, video_path)
                _write_sample(ang_list, label)
